{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "import torch\n",
    "from torch.distributions.multinomial import Multinomial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multinomial Distribution\n",
    "\n",
    "In the aforementioned example, we considered a dataset with only two classes - celebrities and non-celebrities. Now let us say we have a more granular dataset as follows:\n",
    "<table>\n",
    "    <tr>\n",
    "        <td> Index <td> Class <td> Percentage <td> Probability\n",
    "    <tr>\n",
    "    <tr>\n",
    "        <td> 1 <td> Albert Einstein <td> 10% <td> 0.1\n",
    "    <tr>\n",
    "    <tr>\n",
    "        <td> 2 <td> Marie Curie <td> 42% <td> 0.42\n",
    "    <tr>\n",
    "    <tr>\n",
    "        <td> 3 <td> Gauss <td> 4% <td> 0.04\n",
    "    <tr>\n",
    "    <tr>\n",
    "        <td> 4 <td> Others <td> 44 % <td> 0.44\n",
    "    <tr>\n",
    "<table>\n",
    "\n",
    "    \n",
    "Let us perform an experiment where we select 10 photos from the dataset and want to find out the probability that class1 occurs 1 time, class2 occurs 2 times, class3 occurs 1 time and class 4 occurs the remaining 6 times. We can do so using the multinomial distribution, which is an extension of the binomial distribution from 2 variables to m variables.\n",
    "    \n",
    "Formally, let $C_{1}$, $C_{2}$, ..., $C_{m}$ be $m$ classes with probabilities $p_{1}$, $p_{2}$, ..., $p_{m}$. Let $X_{1}$, $X_{2}$, ..., $X_{m}$ be the corresponding random variables in a set of $n$ trials.\n",
    "    \n",
    "Then, the multinomial probability function, depicting probability of $C_{1}$ being selected $k_{1}$ times, $C_{2}$ being selected $k_{2}$ times, $C_{m}$ being selected $k_{m}$ times is \n",
    "   $$ P(X_{1}=k_{1}, X_{2}=k_{2}, .., X_{m}=k_{m}) = \\frac{n!}{k_{1}!k_{2}!...k_{m}!} p_{1}^{k_{1}}p_{2}^{k_{2}}...p_{m}^{k_{m}} $$\n",
    "    \n",
    "where $ \\sum_{i=1}^m k_i = n$ and $ \\sum_{i=1}^m p_i = 1$\n",
    "    \n",
    "Note that for $m=2$, this becomes the binomial distribution. Now, let us implement this in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the parameters of the distribution\n",
    "num_trials = 10\n",
    "p = torch.tensor([0.1, 0.42, 0.04, 0.44], dtype=torch.float)\n",
    "\n",
    "# Instantiate the multinomial distribution\n",
    "multinomial_dist = Multinomial(num_trials, probs=p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Log Prob: -4.350330829620361\n",
      "Raw eval Log Prob: -4.3503313064575195\n"
     ]
    }
   ],
   "source": [
    "# Instantiate single point test dataset\n",
    "X = torch.tensor([1, 2, 1, 6], dtype=torch.float)\n",
    "\n",
    "# Function to evaluate log prob using math formula\n",
    "def raw_eval(X, n, p):\n",
    "    f = math.factorial\n",
    "    result = f(n)\n",
    "    for pi, xi in zip(p, X):\n",
    "        result *= (pi ** xi) / f(xi)\n",
    "    return torch.log(result)\n",
    "\n",
    "# Evaluate log-prob using PyTorch distributions function call\n",
    "log_prob = multinomial_dist.log_prob(X)\n",
    "print(\"Log Prob: {}\".format(log_prob))\n",
    "\n",
    "# Evaluate log-prob using formula\n",
    "raw_eval_log_prob = raw_eval(X, num_trials, p)\n",
    "print(\"Raw eval Log Prob: {}\".format(raw_eval_log_prob))\n",
    "\n",
    "assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of experiment runs\n",
    "num_samples = 100000 \n",
    "\n",
    "# Draw Samples. Each element of the samples array represent the number of successes in that experiment.\n",
    "samples = multinomial_dist.sample(torch.Size([num_samples]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample mean: tensor([0.9966, 4.2037, 0.3987, 4.4010])\n",
      "Dist Mean: tensor([1.0000, 4.2000, 0.4000, 4.4000])\n",
      "Sample variance: tensor([0.9051, 2.4231, 0.3830, 2.4402])\n",
      "Dist Variance: tensor([0.9000, 2.4360, 0.3840, 2.4640])\n"
     ]
    }
   ],
   "source": [
    "# The mean obtained from the samples. Denotes the average number of successes\n",
    "sample_mean = samples.mean(axis=0)\n",
    "print(\"Sample mean: {}\".format(sample_mean))\n",
    "\n",
    "# The mean of the distribution from Pytorch\n",
    "dist_mean = multinomial_dist.mean\n",
    "print(\"Dist Mean: {}\".format(dist_mean))\n",
    "\n",
    "# As expected, the two means approximately match and are equal to [num_trials * p1, num_trials * p2, num_trials * pm]\n",
    "assert torch.allclose(sample_mean, dist_mean, atol=1e-1)\n",
    "\n",
    "# The variance obtained from the samples\n",
    "sample_var = multinomial_dist.sample([num_samples]).var(axis=0)\n",
    "print(\"Sample variance: {}\".format(sample_var))\n",
    "\n",
    "# The variance of the distribution from Pytorch\n",
    "dist_var = multinomial_dist.variance\n",
    "print(\"Dist Variance: {}\".format(dist_var))\n",
    "\n",
    "# As expected, the two variances approximately match.\n",
    "assert torch.allclose(sample_var, dist_var, atol=1e-1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
